/*
 * Copyright (c) 2010-2020. Axon Framework
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.axonframework.spring.eventsourcing;

import org.axonframework.config.Configuration;
import org.axonframework.eventhandling.DomainEventMessage;
import org.axonframework.eventsourcing.AbstractAggregateFactory;
import org.axonframework.eventsourcing.AggregateFactory;
import org.axonframework.eventsourcing.IncompatibleAggregateException;
import org.axonframework.modelling.command.inspection.AggregateModel;
import org.axonframework.modelling.command.inspection.AnnotatedAggregateMetaModelFactory;
import org.springframework.beans.factory.BeanNameAware;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import static java.lang.String.format;

/**
 * AggregateFactory implementation that uses Spring prototype beans to create new uninitialized instances of
 * Aggregates.
 * <p>
 * The aggregate type is resolved from the {@link ApplicationContext} using the configured prototype bean name.
 * Actual creation is handled by an internal {@link AbstractAggregateFactory} delegate, which is built in
 * {@link #afterPropertiesSet()} and retrieves new instances from the application context.
 *
 * @param <T> The type of aggregate generated by this aggregate factory
 * @author Allard Buijze
 * @since 1.2
 */
public class SpringPrototypeAggregateFactory<T> implements AggregateFactory<T>, InitializingBean,
                                                           ApplicationContextAware, BeanNameAware {

    // Name of the prototype-scoped bean definition of the root aggregate type.
    private final String prototypeBeanName;
    private ApplicationContext applicationContext;
    // Name of this factory bean itself; only used in the error message of afterPropertiesSet().
    private String beanName;
    // Lazily resolved from the application context, see getAggregateType().
    private Class<T> aggregateType;
    // Factory-owned mapping of aggregate (sub)type to the name of its prototype bean.
    private final Map<Class<? extends T>, String> subtypes;
    // Assigned in afterPropertiesSet(); null until then.
    private AggregateFactory<T> delegate;

    /**
     * Initializes the factory to create bean instances for the bean with given {@code prototypeBeanName}.
     * <p>
     * Note that the bean should have the prototype scope.
     *
     * @param prototypeBeanName the name of the prototype bean this factory serves.
     */
    public SpringPrototypeAggregateFactory(String prototypeBeanName) {
        this(prototypeBeanName, new HashMap<>());
    }

    /**
     * Initializes the factory to create bean instances for the bean with given {@code prototypeBeanName} and its
     * {@code subtypes}.
     * <p>
     * Note that the bean should have the prototype scope.
     * <p>
     * The given {@code subtypes} map is copied; it is neither retained nor modified by this factory, so immutable
     * maps may be passed safely. A {@code null} map is treated as "no subtypes".
     *
     * @param prototypeBeanName the name of the prototype bean this factory serves.
     * @param subtypes          the map of subtype of this aggregate to its spring prototype name
     */
    public SpringPrototypeAggregateFactory(String prototypeBeanName, Map<Class<? extends T>, String> subtypes) {
        this.prototypeBeanName = prototypeBeanName;
        // Defensive copy: setApplicationContext(...) adds the root aggregate type to this map, which must not
        // leak into (or fail on) the map instance owned by the caller.
        this.subtypes = new HashMap<>();
        if (subtypes != null) {
            this.subtypes.putAll(subtypes);
        }
    }

    /**
     * Creates an aggregate instance by forwarding to the delegate factory built in {@link #afterPropertiesSet()}.
     * This method must not be invoked before {@link #afterPropertiesSet()} has completed, as the delegate is not
     * available until then.
     *
     * @param aggregateIdentifier the identifier of the aggregate to create
     * @param firstEvent          the first event of the aggregate's event stream
     * @return the aggregate instance produced by the delegate
     */
    @Override
    public T createAggregateRoot(String aggregateIdentifier, DomainEventMessage<?> firstEvent) {
        return delegate.createAggregateRoot(aggregateIdentifier, firstEvent);
    }

    /**
     * Returns the aggregate type, as reported by the application context for the prototype bean name. The result
     * is cached after the first successful lookup. Requires the application context to have been set.
     *
     * @return the type of the prototype bean this factory serves
     */
    @SuppressWarnings("unchecked")
    @Override
    public Class<T> getAggregateType() {
        if (aggregateType == null) {
            aggregateType = (Class<T>) applicationContext.getType(prototypeBeanName);
        }
        return aggregateType;
    }

    /**
     * Stores the application context and registers the root aggregate type, mapped to the prototype bean name, in
     * the internal type-to-bean-name mapping.
     *
     * @param applicationContext the application context to retrieve prototype beans from
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
        this.subtypes.put(getAggregateType(), prototypeBeanName);
    }

    /**
     * Stores the name under which this factory is registered; it is used for error reporting only.
     *
     * @param beanName the name of this factory bean
     */
    @Override
    public void setBeanName(String beanName) {
        this.beanName = beanName;
    }

    /**
     * Validates that the configured bean is prototype scoped, inspects the aggregate model and builds the delegate
     * factory that creates the actual aggregate instances. Requires the application context to have been set.
     *
     * @throws IncompatibleAggregateException if the bean with the configured prototype bean name does not have the
     *                                        'prototype' scope
     */
    @SuppressWarnings("unchecked")
    @Override
    public void afterPropertiesSet() {
        if (!applicationContext.isPrototype(prototypeBeanName)) {
            throw new IncompatibleAggregateException(
                    format("Cannot initialize repository '%s'. "
                                   + "The bean with name '%s' does not have the 'prototype' scope.",
                           beanName, prototypeBeanName));
        }
        AggregateModel<T> model;
        if (applicationContext.getBeanNamesForType(Configuration.class).length > 0) {
            // An Axon Configuration bean is available: use its parameter resolvers and handler definitions.
            Configuration configuration = applicationContext.getBean(Configuration.class);
            model = AnnotatedAggregateMetaModelFactory.inspectAggregate(getAggregateType(),
                                                                        configuration.parameterResolverFactory(),
                                                                        configuration
                                                                                .handlerDefinition(getAggregateType()),
                                                                        subtypes.keySet());
        } else {
            // No Configuration bean present: inspect the aggregate without the configuration's components.
            model = AnnotatedAggregateMetaModelFactory.inspectAggregate(getAggregateType(),
                                                                        subtypes.keySet());
        }
        this.delegate = new AbstractAggregateFactory<T>(model) {
            @Override
            protected T doCreateAggregate(String aggregateIdentifier, DomainEventMessage firstEvent) {
                return (T) applicationContext.getBean(prototype(firstEvent.getType()));
            }

            // Resolves the name of the prototype bean to instantiate for the given declared aggregate type.
            private String prototype(String aggregateType) {
                return aggregateModel().type(aggregateType)
                                       .map(subtypes::get)
                                       // for backwards compatibility, in cases where firstEvent does not contain
                                       // the aggregate type
                                       .orElse(prototypeBeanName);
            }

            @Override
            protected T postProcessInstance(T aggregate) {
                // Fall back to the root prototype bean name when the runtime class of the aggregate is not a
                // registered type, mirroring prototype(...), so a null bean name is never handed to Spring.
                String name = subtypes.get(aggregate.getClass());
                applicationContext.getAutowireCapableBeanFactory()
                                  .configureBean(aggregate, name != null ? name : prototypeBeanName);
                return aggregate;
            }
        };
    }
}
